// Copyright Epic Games, Inc. All Rights Reserved.

#include <zencore/zencore.h>

#if ZEN_PLATFORM_WINDOWS
#	include <zencore/windows.h>
#endif

#include <zencore/assertfmt.h>
#include <zencore/blake3.h>
#include <zencore/callstack.h>
#include <zencore/compactbinary.h>
#include <zencore/compactbinarybuilder.h>
#include <zencore/compactbinarypackage.h>
#include <zencore/compositebuffer.h>
#include <zencore/compress.h>
#include <zencore/crypto.h>
#include <zencore/filesystem.h>
#include <zencore/intmath.h>
#include <zencore/iobuffer.h>
#include <zencore/jobqueue.h>
#include <zencore/logging.h>
#include <zencore/memory.h>
#include <zencore/mpscqueue.h>
#include <zencore/process.h>
#include <zencore/sha1.h>
#include <zencore/stats.h>
#include <zencore/stream.h>
#include <zencore/string.h>
#include <zencore/thread.h>
#include <zencore/timer.h>
#include <zencore/uid.h>
#include <zencore/workthreadpool.h>

#include <fmt/format.h>

#include <atomic>
#include <utility>

namespace zen::assert {

void
ExecAssertFmt(const char* Filename, int LineNumber, const char* FunctionName, std::string_view Format, fmt::format_args Args)
{
	fmt::basic_memory_buffer<char, 1024> Message;
	fmt::vformat_to(fmt::appender(Message), Format, Args);
	Message.push_back('\0');

	AssertImpl::ExecAssert(Filename, LineNumber, FunctionName, Message.data());
}

}  // namespace zen::assert

namespace zen {

AssertException::AssertException(const AssertException& Rhs) noexcept
: _Mybase(Rhs)
, _Callstack(CloneCallstack(Rhs._Callstack))
{
	// Every AssertException frees its own callstack on destruction, so a copy must hold its
	// own clone instead of sharing Rhs's pointer.
}

// Move constructor: takes over Rhs's callstack without cloning it.
// The base subobject is moved (not copied) from Rhs; the callstack pointer is transferred and
// Rhs is left holding nullptr so its destructor does not free the same frames a second time.
AssertException::AssertException(AssertException&& Rhs) noexcept
: _Mybase(std::move(Rhs))
, _Callstack(std::exchange(Rhs._Callstack, nullptr))
{
}

AssertException::~AssertException() noexcept
{
	// Release the callstack this exception owns. A moved-from instance holds nullptr here
	// (the move constructor clears it); presumably FreeCallstack treats nullptr as a no-op.
	CallstackFrames* Owned = _Callstack;
	FreeCallstack(Owned);
}

AssertException&
AssertException::operator=(const AssertException& Rhs) noexcept
{
	_Mybase::operator=(Rhs);

	// Clone first and free the old frames last, so assigning an exception to itself is safe.
	CallstackFrames* Replacement = CloneCallstack(Rhs._Callstack);
	CallstackFrames* Previous	 = _Callstack;
	_Callstack					 = Replacement;
	FreeCallstack(Previous);

	return *this;
}

std::string
AssertException::FullDescription() const noexcept
{
	// Without a captured callstack the description is just the exception message.
	if (!_Callstack)
	{
		return what();
	}

	// Otherwise: the quoted message, then the rendered callstack on the following lines.
	return fmt::format("'{}'\n{}", what(), CallstackToString(_Callstack));
}

AssertImpl::AssertImpl() : NextAssertImpl(nullptr)
{
	// Register this handler at the tail of the singly-linked handler chain, so that
	// ExecAssert (which walks from CurrentAssertImpl) invokes handlers in construction order.
	// No synchronization is performed here.
	if (CurrentAssertImpl == nullptr)
	{
		CurrentAssertImpl = this;
		return;
	}

	AssertImpl* Tail = CurrentAssertImpl;
	while (Tail->NextAssertImpl != nullptr)
	{
		Tail = Tail->NextAssertImpl;
	}
	Tail->NextAssertImpl = this;
}

AssertImpl::~AssertImpl()
{
	// Unlink this handler from the chain. It is expected to be registered; running off the
	// end of the list without finding it trips the assertion instead of dereferencing null.
	AssertImpl** Link = &CurrentAssertImpl;
	for (; *Link != this; Link = &(*Link)->NextAssertImpl)
	{
		ZEN_ASSERT((*Link) != nullptr);
	}
	*Link = NextAssertImpl;
}

void
AssertImpl::ExecAssert(const char* Filename, int LineNumber, const char* FunctionName, const char* Msg)
{
	void*	 Frames[8];
	uint32_t FrameCount = GetCallstack(2, 8, Frames);

	CallstackFrames* Callstack = CreateCallstack(FrameCount, Frames);

	AssertImpl* AssertImpl = CurrentAssertImpl;
	while (AssertImpl)
	{
		AssertImpl->OnAssert(Filename, LineNumber, FunctionName, Msg, Callstack);
		AssertImpl = AssertImpl->NextAssertImpl;
	}
	ThrowAssertException(Filename, LineNumber, FunctionName, Msg, Callstack);
}
void
AssertImpl::OnAssert(const char* Filename, int LineNumber, const char* FunctionName, const char* Msg, CallstackFrames* Callstack)
{
	ZEN_UNUSED(FunctionName);

	fmt::basic_memory_buffer<char, 2048> Message;
	auto								 Appender = fmt::appender(Message);
	fmt::format_to(Appender, "{}({}): ZEN_ASSERT({})\n{}", Filename, LineNumber, Msg, CallstackToString(Callstack));
	Message.push_back('\0');

	// We use direct ZEN_LOG here instead of ZEN_ERROR as we don't care about *this* code location in the log
	ZEN_LOG(Log(), zen::logging::level::Err, "{}", Message.data());
	zen::logging::FlushLogging();
}

void
AssertImpl::ThrowAssertException(const char*	  Filename,
								 int			  LineNumber,
								 const char*	  FunctionName,
								 const char*	  Msg,
								 CallstackFrames* Callstack)
{
	ZEN_UNUSED(FunctionName);
	fmt::basic_memory_buffer<char, 2048> Message;
	auto								 Appender = fmt::appender(Message);
	fmt::format_to(Appender, "{}({}): {}", Filename, LineNumber, Msg);
	Message.push_back('\0');

	throw AssertException(Message.data(), Callstack);
}

// Declared here so zencore_forcelinktests() below can call it; its definition is not in this file.
void refcount_forcelink();

// Head of the singly-linked chain of registered assert handlers, maintained by the AssertImpl
// constructor/destructor. Initialized with a constant (nullptr), so it is already null before any
// dynamically-initialized object, such as DefaultAssertImpl below, runs its constructor.
AssertImpl* AssertImpl::CurrentAssertImpl = nullptr;
// Built-in handler instance whose OnAssert logs the failure; its constructor links it into the
// chain during this translation unit's dynamic initialization.
AssertImpl	AssertImpl::DefaultAssertImpl;

//////////////////////////////////////////////////////////////////////////

// Reports whether a user-mode debugger is attached to this process. Detection is only
// implemented on Windows; every other platform reports false.
bool
IsDebuggerPresent()
{
	bool bAttached = false;
#if ZEN_PLATFORM_WINDOWS
	bAttached = ::IsDebuggerPresent() != FALSE;
#endif
	return bAttached;
}

// Cached answer for IsInteractiveSession(); empty until first queried or explicitly set.
std::optional<bool> InteractiveSessionFlag;

// Overrides (or pre-seeds) the value returned by later IsInteractiveSession() calls.
void
SetIsInteractiveSession(bool Value)
{
	InteractiveSessionFlag.emplace(Value);
}

// Reports whether the process runs in an interactive session. Detection happens on the first
// call (unless SetIsInteractiveSession() was called first) and the result is cached.
// The cache is not synchronized.
bool
IsInteractiveSession()
{
	if (InteractiveSessionFlag.has_value())
	{
		return *InteractiveSessionFlag;
	}

#if ZEN_PLATFORM_WINDOWS
	// Any session other than 0 counts as interactive (on Windows Vista and later, session 0 is
	// reserved for services); a failed lookup counts as non-interactive.
	DWORD	   SessionId	 = 0;
	const bool bGotSessionId = ProcessIdToSessionId(GetCurrentProcessId(), &SessionId) != FALSE;
	InteractiveSessionFlag	 = bGotSessionId && SessionId != 0;
#else
	// TODO: figure out what actually makes sense here
	InteractiveSessionFlag = true;
#endif

	return *InteractiveSessionFlag;
}

//////////////////////////////////////////////////////////////////////////

// One-shot latch claimed by the first RequestApplicationExit() call.
static std::atomic_bool s_ApplicationExitClaimed{false};
// Exit code recorded by the call that claimed the latch; 0 until then.
static std::atomic_int s_ApplicationExitCode{0};
// Set only after s_ApplicationExitCode has been stored (both seq_cst), so a thread that observes
// IsApplicationExitRequested() == true also observes the final exit code.
static std::atomic_bool s_ApplicationExitRequested{false};

// Returns true once an application exit has been requested and its exit code recorded.
bool
IsApplicationExitRequested()
{
	return s_ApplicationExitRequested;
}

// Requests application exit with ExitCode. Only the first request takes effect.
// Returns true if this call recorded the request, false if an earlier call already did.
bool
RequestApplicationExit(int ExitCode)
{
	// exchange() is a single atomic read-modify-write that cannot fail spuriously (unlike
	// compare_exchange_weak), so exactly one caller ever sees the previous value 'false'.
	if (s_ApplicationExitClaimed.exchange(true))
	{
		return false;
	}

	// Publish the code before the flag so observers of the flag never read a stale code.
	s_ApplicationExitCode	   = ExitCode;
	s_ApplicationExitRequested = true;
	return true;
}

// Returns the exit code passed to the first RequestApplicationExit() call, or 0 if none was made.
int
ApplicationExitCode()
{
	return s_ApplicationExitCode;
}

#if ZEN_WITH_TESTS
void
zencore_forcelinktests()
{
	zen::blake3_forcelink();
	zen::callstack_forcelink();
	zen::compositebuffer_forcelink();
	zen::compress_forcelink();
	zen::crypto_forcelink();
	zen::filesystem_forcelink();
	zen::jobqueue_forcelink();
	zen::intmath_forcelink();
	zen::iobuffer_forcelink();
	zen::logging_forcelink();
	zen::memory_forcelink();
	zen::mpscqueue_forcelink();
	zen::process_forcelink();
	zen::refcount_forcelink();
	zen::sha1_forcelink();
	zen::stats_forcelink();
	zen::stream_forcelink();
	zen::string_forcelink();
	zen::thread_forcelink();
	zen::timer_forcelink();
	zen::uid_forcelink();
	zen::uson_forcelink();
	zen::usonbuilder_forcelink();
	zen::usonpackage_forcelink();
	zen::cbjson_forcelink();
	zen::cbyaml_forcelink();
	zen::workthreadpool_forcelink();
}
}  // namespace zen

#	include <zencore/testing.h>

namespace zen {

// The tests below belong to the "core.assert" doctest suite (closed by TEST_SUITE_END at the end).
TEST_SUITE_BEGIN("core.assert");

// A failing ZEN_ASSERT must throw with the message "<file>(<line>): <expression text>".
// NOTE: Expected is computed with __LINE__ + 1, so the CHECK_THROWS_WITH statement must stay on
// the line immediately following the Expected line.
TEST_CASE("Assert.Default")
{
	bool		A		 = true;
	bool		B		 = false;
	std::string Expected = fmt::format("{}({}): {}", __FILE__, __LINE__ + 1, "A == B");
	CHECK_THROWS_WITH(ZEN_ASSERT(A == B), Expected.c_str());
}

// A failing ZEN_ASSERT_FORMAT must throw with "<file>(<line>): " followed by the formatted message;
// for these inputs the test expects "assert(A == B) failed: true == false".
// NOTE: Expected is computed with __LINE__ + 1, so the CHECK_THROWS_WITH statement must stay on
// the line immediately following the Expected line.
TEST_CASE("Assert.Format")
{
	bool		A		 = true;
	bool		B		 = false;
	std::string Expected = fmt::format("{}({}): {}", __FILE__, __LINE__ + 1, "assert(A == B) failed: true == false");
	CHECK_THROWS_WITH(ZEN_ASSERT_FORMAT(A == B, "{} == {}", A, B), Expected.c_str());
}

// A user-defined AssertImpl must be invoked with the assertion details, and the assertion must
// still throw with the standard "<file>(<line>): <expression text>" message.
TEST_CASE("Assert.Custom")
{
	// Handler that records the arguments it receives instead of logging them.
	struct MyAssertImpl : AssertImpl
	{
		void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION
		OnAssert(const char* Filename, int LineNumber, const char* FunctionName, const char* Msg, CallstackFrames* Callstack) override
		{
			ZEN_UNUSED(Callstack);
			AssertFileName = Filename;
			Line		   = LineNumber;
			FuncName	   = FunctionName;
			Message		   = Msg;
		}

		const char* AssertFileName = nullptr;
		int			Line		   = -1;
		const char* FuncName	   = nullptr;
		const char* Message		   = nullptr;
	};

	// Constructing the handler registers it in the assert chain; its destructor unregisters it.
	MyAssertImpl MyAssert;
	bool		 A = true;
	bool		 B = false;
	// Keep the expected text in a named string: a temporary's c_str() handed to CHECK_THROWS_WITH
	// can dangle with doctest versions that store the raw pointer. Expected uses __LINE__ + 1, so
	// the CHECK_THROWS_WITH statement must stay on the very next line.
	std::string	 Expected = fmt::format("{}({}): {}", __FILE__, __LINE__ + 1, "A == B");
	CHECK_THROWS_WITH(ZEN_ASSERT(A == B), Expected.c_str());
	CHECK(MyAssert.AssertFileName != nullptr);
	CHECK(MyAssert.Line != -1);
	CHECK(MyAssert.FuncName != nullptr);
	// Guard the strcmp: Message stays null if the handler was never invoked.
	REQUIRE(MyAssert.Message != nullptr);
	CHECK(strcmp(MyAssert.Message, "A == B") == 0);
}

// An AssertException must carry a non-empty callstack, both when thrown on the calling thread and
// when the assertion fires inside a worker-pool task and is rethrown by Task.get().
TEST_CASE("Assert.Callstack")
{
	try
	{
		ZEN_ASSERT(false);
		CHECK(false);  // Not reached: ZEN_ASSERT(false) must throw.
	}
	catch (const AssertException& Assert)
	{
		ZEN_INFO("Assert failed: {}", Assert.what());
		REQUIRE(Assert._Callstack != nullptr);
		CHECK(Assert._Callstack->FrameCount > 0);
		CHECK(Assert._Callstack->Frames != nullptr);
		ZEN_INFO("Callstack:\n{}", CallstackToString(Assert._Callstack));
	}

	WorkerThreadPool Pool(1);
	auto			 Task = Pool.EnqueueTask(std::packaged_task<int()>{[] {
		ZEN_ASSERT(false);
		return 1;
	}});

	try
	{
		(void)Task.get();
		CHECK(false);  // Not reached: the task's assertion must surface through get().
	}
	catch (const AssertException& Assert)
	{
		ZEN_INFO("Assert in future: {}", Assert.what());
		REQUIRE(Assert._Callstack != nullptr);
		CHECK(Assert._Callstack->FrameCount > 0);
		CHECK(Assert._Callstack->Frames != nullptr);
		ZEN_INFO("Callstack:\n{}", CallstackToString(Assert._Callstack));
	}
}

TEST_SUITE_END();
#endif

}  // namespace zen
